import tensorflow as tf
from model_get import *
from utils import *
from sklearn.metrics import roc_auc_score


# Location of the word2vec model file handed to loadWord2Vec.
W2V_PATH = "../data/w2v_model/word2vec_20190626.model"
# loadWord2Vec is provided by one of the star imports above. Presumably it
# returns (vocabulary, embedding matrix, word -> index mapping) -- TODO confirm
# against its definition.
vocab, embd, vocab_dic = loadWord2Vec(W2V_PATH)
vocab_size = len(vocab)
# Length of the first embedding row.
# NOTE(review): embedding_dim is not used in the visible part of this script.
embedding_dim = len(embd[0])
# Value fed to the "embedding_placeholder:0" tensor on every sess.run call.
embedding_placeholder = embd

config = TCNNConfig()

# NOTE(review): this alias is not used in the visible part of this script;
# the session code reads config.keep_prob directly.
keep_prob = config.keep_prob
# The vocabulary size is set on the config before the model is constructed
# from it.
config.vocab_size = vocab_size
model = TextCNN(config)

# Keep probability fed during evaluation.
# NOTE(review): assumes "keep_prob:0" is the dropout keep probability of the
# TextCNN graph, so 1.0 disables dropout and makes the reported metrics
# deterministic -- confirm against the model definition.
_EVAL_KEEP_PROB = 1.0


def _make_feed(batch_x, batch_y, keep_prob_value):
    """Build the feed dict shared by every ``sess.run`` call in this script.

    Args:
        batch_x: value for the ``input_x:0`` tensor.
        batch_y: value for the ``input_y:0`` tensor.
        keep_prob_value: value for the ``keep_prob:0`` tensor.

    Returns:
        A dict mapping tensor names to the values to feed.
    """
    return {
        "input_x:0": batch_x,
        "input_y:0": batch_y,
        "embedding_placeholder:0": embedding_placeholder,
        "keep_prob:0": keep_prob_value,
    }


def _evaluate_test(sess):
    """Score the "test" split once and return its metrics.

    All three model outputs are fetched in a single ``sess.run`` so that
    accuracy, probabilities and predicted classes come from the same forward
    pass, instead of three independent passes.

    Args:
        sess: the active ``tf.Session`` holding the trained variables.

    Returns:
        Tuple ``(acc, precision, recall, f1, auc)``.
    """
    test_x, test_y = load_batch("20190626", "test", config.num_classes,
                                config.seq_length, vocab_dic)
    acc, test_prob, test_res = sess.run(
        [model.acc, model.softmax, model.y_pred_cls],
        feed_dict=_make_feed(test_x, test_y, _EVAL_KEEP_PROB))
    # Column 1 of the labels / probabilities is used as the positive class,
    # exactly as in the original metric computation.
    y_true = [label[1] for label in test_y]
    auc = roc_auc_score(y_true, [prob[1] for prob in test_prob])
    p, r, f1score = model_rep(y_true, test_res)
    return acc, p, r, f1score, auc


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(config.num_epochs):
        for batch_idx in range(config.num_batches):
            input_x, input_y = load_batch("20190626", batch_idx, config.num_classes,
                                          config.seq_length, vocab_dic)
            # Training step: the only place the training keep_prob is fed.
            sess.run(model.optim,
                     feed_dict=_make_feed(input_x, input_y, config.keep_prob))
            # Per-batch progress report.
            if batch_idx % config.print_per_batch == 0:
                # One run for both values, so loss and acc describe the same
                # forward pass.
                loss, acc = sess.run(
                    [model.loss, model.acc],
                    feed_dict=_make_feed(input_x, input_y, _EVAL_KEEP_PROB))
                print("epoch: %s, batch_idx: %s, loss: %s, acc: %s" % (epoch, batch_idx, loss, acc))
        # Per-epoch report: evaluate on the test split.
        acc, p, r, f1score, auc = _evaluate_test(sess)
        print("test epoch: %s, acc: %s, precision: %s, recall: %s. f1: %s, auc: %s" % (epoch, acc, p, r, f1score, auc))

    # Final test-set evaluation after all epochs have finished.
    acc, p, r, f1score, auc = _evaluate_test(sess)
    print("test result, acc: %s, precision: %s, recall: %s. f1: %s, auc: %s" % (acc, p, r, f1score, auc))
